{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# for real data\n",
    "# create dataset of number sequences\n",
    "# let's assume that we have a vocabulary size of 1000 words\n",
    "# let's assume that 0 is the EOS token, and 1 is the SOS token, and 2 is PAD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def toData(batch):\n",
    "    # [input] batch: list of strings\n",
    "    # [output] input_out, output_out: np array([b x seq]), fixed size, eos & zero padding applied\n",
    "    # [output] in_idx, out_idx: np.array([b]), length of each line in seq\n",
    "    batch = [line.replace('\\n','') for line in batch]\n",
    "    inputs_ = []\n",
    "    outputs_ = []\n",
    "    in_len = []\n",
    "    out_len = []\n",
    "    for line in batch:\n",
    "        inputs, outputs = line.split('::')\n",
    "#         outputs, inputs = line.split('::')\n",
    "        inputs_.append([int(num) for num in inputs.split(' ')])\n",
    "        outputs_.append([int(num) for num in outputs.split(' ')])\n",
    "        in_len.append(len(inputs_[-1]))\n",
    "        out_len.append(len(outputs_[-1]))\n",
    "    in_len = np.array(in_len)\n",
    "    out_len = np.array(out_len)\n",
    "    max_in = max(in_len)\n",
    "    max_out = max(out_len)\n",
    "    batch_size = len(batch)\n",
    "    input_out = np.zeros([batch_size,max_in],dtype=int)\n",
    "    output_out = np.zeros([batch_size,max_out],dtype=int)\n",
    "    for b in range(batch_size):\n",
    "        input_out[b][:in_len[b]] = np.array(inputs_[b])\n",
    "        output_out[b][:out_len[b]] = np.array(outputs_[b])\n",
    "    out_rev = out_len.argsort()[::-1]\n",
    "    return input_out[out_rev], output_out[out_rev], in_len[out_rev], out_len[out_rev]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "7667\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "w2i = np.load('data/en-django/en-django/w2i.npy').item()\n",
    "i2w = np.load('data/en-django/en-django/i2w.npy').item()\n",
    "vocab_size = len(w2i)\n",
    "print(vocab_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import numpy as np\n",
    "import torch.optim as optim\n",
    "from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence\n",
    "from torch.autograd import Variable\n",
    "import torch.nn.functional as F\n",
    "# from models.copynet import CopyEncoder, CopyDecoder\n",
    "from models.copynet_dbg import CopyEncoder, CopyDecoder\n",
    "from models.functions import numpy_to_var, to_np, to_var, visualize, decoder_initial, update_logger\n",
    "import time\n",
    "import sys\n",
    "import math\n",
    "torch.manual_seed(1000)\n",
    "\n",
    "# Hyperparameters\n",
    "embed_size = 150\n",
    "hidden_size = 300\n",
    "num_layers = 1\n",
    "bin_size = 10\n",
    "num_epochs = 40\n",
    "prev_end=0\n",
    "batch_size = 16\n",
    "lr = 0.001\n",
    "vocab_size = 100\n",
    "weight_decay = 0.99\n",
    "use_saved = True # whether to train from a previous model\n",
    "continue_from = 7\n",
    "# version = 'django'\n",
    "version = 'django_fixed'\n",
    "step = 0 # number of steps taken\n",
    "\n",
    "# input and output directories\n",
    "w2i = np.load('data/en-django/en-django/w2i.npy').item()\n",
    "i2w = np.load('data/en-django/en-django/i2w.npy').item()\n",
    "vocab_size = len(w2i)\n",
    "# file_dir = 'data/en-django/en-django/idx_lists.txt'\n",
    "file_dir = 'data/en-django/en-django/idx_lists_fixed.txt'\n",
    "# get training and test data\n",
    "with open(file_dir) as f:\n",
    "    lines = f.readlines()\n",
    "\n",
    "import random\n",
    "random.shuffle(lines)\n",
    "test = lines[:200]\n",
    "train = lines[200:]\n",
    "\n",
    "# get number of batches\n",
    "num_samples = len(train)\n",
    "num_batches = int(num_samples/batch_size)\n",
    "\n",
    "################ load copynet model #####################\n",
    "if use_saved:\n",
    "    # if using from previous data\n",
    "    encoder_dir = 'models/encoder_%s_%s.pckl' % (version,continue_from)\n",
    "    decoder_dir = 'models/decoder_%s_%s.pckl' % (version,continue_from)\n",
    "    encoder = torch.load(f=encoder_dir)\n",
    "    decoder = torch.load(f=decoder_dir)\n",
    "else:\n",
    "    encoder = CopyEncoder(vocab_size, embed_size, hidden_size)\n",
    "    decoder = CopyDecoder(vocab_size, embed_size, hidden_size)\n",
    "    continue_from = 0\n",
    "if torch.cuda.is_available():\n",
    "    encoder.cuda()\n",
    "    decoder.cuda()\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "help(encoder.load_state_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "################################# training ##################################\n",
    "\n",
    "# set loss\n",
    "criterion = nn.NLLLoss()\n",
    "\n",
    "start = time.time()\n",
    "for epoch in range(num_epochs):\n",
    "    print(\"==================================================\")\n",
    "    print(\"Epoch \",epoch+1)\n",
    "    opt_e = optim.Adam(params=encoder.parameters(), lr=lr)\n",
    "    opt_d = optim.Adam(params=decoder.parameters(), lr=lr)\n",
    "    lr= lr * weight_decay # weight decay\n",
    "    # shuffle data\n",
    "    random.shuffle(train)\n",
    "    samples_read = 0\n",
    "    while(samples_read<len(train)):\n",
    "        # initialize gradient buffers\n",
    "        opt_e.zero_grad()\n",
    "        opt_d.zero_grad()\n",
    "\n",
    "        # obtain batch outputs\n",
    "        batch = train[samples_read:min(samples_read+batch_size,len(train))]\n",
    "        annotations, codes, in_len, out_len = toData(batch)\n",
    "        output_out = output_out[:,:50]\n",
    "        out_len = np.array([min(50,x) for x in out_len])\n",
    "#         print(in_len.shape)\n",
    "#         print(out_len.shape)\n",
    "        samples_read+=len(batch)\n",
    "\n",
    "        # mask input to remove padding\n",
    "        input_mask = np.array(input_out>0, dtype=int)\n",
    "\n",
    "        # input and output in Variable form\n",
    "#         x = numpy_to_var(input_out)\n",
    "#         y = numpy_to_var(output_out)\n",
    "        x = numpy_to_var(codes)\n",
    "        y = numpy_to_var(annotations)\n",
    "\n",
    "        # apply to encoder\n",
    "        encoded, _ = encoder(x)\n",
    "\n",
    "        # get initial input of decoder\n",
    "        decoder_in, s, w = decoder_initial(x.size(0))\n",
    "\n",
    "        # out_list to store outputs\n",
    "        out_list=[]\n",
    "        for j in range(y.size(1)): # for all sequences\n",
    "            \"\"\"\n",
    "            decoder_in (Variable): [b]\n",
    "            encoded (Variable): [b x seq x hid]\n",
    "            input_out (np.array): [b x seq]\n",
    "            s (Variable): [b x hid]\n",
    "            \"\"\"\n",
    "            # 1st state\n",
    "#             print(j)\n",
    "#             print(out.size())\n",
    "            if j==0:\n",
    "                out, s, w = decoder(input_idx=decoder_in, encoded=encoded,\n",
    "                                encoded_idx=input_out, prev_state=s,\n",
    "                                weighted=w, order=j)\n",
    "            # remaining states\n",
    "            else:\n",
    "                tmp_out, s, w = decoder(input_idx=decoder_in, encoded=encoded,\n",
    "                                encoded_idx=input_out, prev_state=s,\n",
    "                                weighted=w, order=j)\n",
    "                out = torch.cat([out,tmp_out],dim=1)\n",
    "            # for debugging: stop if nan\n",
    "            if math.isnan(w[-1][0][0].data[0]):\n",
    "                sys.exit()\n",
    "            # select next input\n",
    "\n",
    " \n",
    "            decoder_in = y[:,j] # train with ground truth\n",
    "#             out_list.append(out[:,-1].max(1)[1].squeeze().cpu().data.numpy())\n",
    "\n",
    "        # print(torch.stack(decoder.prob_c_to_g,1))\n",
    "        target = pack_padded_sequence(y,out_len.tolist(), batch_first=True)[0]\n",
    "        pad_out = pack_padded_sequence(out,out_len.tolist(), batch_first=True)[0]\n",
    "        # include log computation as we are using log-softmax and NLL\n",
    "        pad_out = torch.log(pad_out)\n",
    "        loss = criterion(pad_out, target)\n",
    "        loss.backward()\n",
    "        if samples_read%1==0:\n",
    "            print(\"[%d/%d] Loss: %1.4f\"%(samples_read,len(train),loss.data[0]))\n",
    "        opt_e.step()\n",
    "        opt_d.step()\n",
    "        step += 1\n",
    "        info = {\n",
    "            'loss': loss.data[0]\n",
    "        }\n",
    "    # print(\"Loss: \",loss.data[0])\n",
    "    elapsed = time.time()\n",
    "    print(\"Elapsed time for epoch: \",elapsed-start)\n",
    "    start = time.time()\n",
    "\n",
    "    torch.save(f='models/encoder_%s_%s.pckl' % (version,str(epoch+continue_from)),obj=encoder)\n",
    "    torch.save(f='models/decoder_%s_%s.pckl' % (version,str(epoch+continue_from)),obj=decoder)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "CopyDecoder (\n",
       "  (embed): Embedding(7667, 150)\n",
       "  (gru): GRU(750, 300, batch_first=True)\n",
       "  (Ws): Linear (600 -> 300)\n",
       "  (Wo): Linear (300 -> 7667)\n",
       "  (Wc): Linear (600 -> 300)\n",
       "  (nonlinear): Tanh ()\n",
       ")"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# encoder_debug = encoder\n",
    "# decoder_debug = decoder\n",
    "encoder_debug = CopyEncoder(vocab_size, embed_size, hidden_size)\n",
    "decoder_debug = CopyDecoder(vocab_size, embed_size, hidden_size)\n",
    "encoder_debug.cuda()\n",
    "decoder_debug.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "encoder_debug.load_state_dict(encoder.state_dict())\n",
    "decoder_debug.load_state_dict(decoder.state_dict())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def IdxToWords(idx_list,dic):\n",
    "    return [dic[x] for x in idx_list]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "################################# validation ##################################\n",
    "input_out, output_out, in_len, out_len = toData(test[50:70])\n",
    "input_mask = np.array(input_out>0, dtype=int)\n",
    "x = numpy_to_var(input_out)\n",
    "y = numpy_to_var(output_out)\n",
    "encoded, _ = encoder(x)\n",
    "decoder_in, s, w = decoder_initial(x.size(0))\n",
    "out_list=[]\n",
    "for j in range(y.size(1)): # for all sequences\n",
    "    if j==0:\n",
    "#         out, s, w = decoder(input_idx=decoder_in, encoded=encoded,\n",
    "        out, s, w = decoder_debug(input_idx=decoder_in, encoded=encoded,\n",
    "                        encoded_idx=input_out, prev_state=s, \n",
    "                        weighted=w, order=j)\n",
    "    else:\n",
    "#         tmp_out, s, w = decoder(input_idx=decoder_in, encoded=encoded,\n",
    "        tmp_out, s, w = decoder_debug(input_idx=decoder_in, encoded=encoded,\n",
    "                        encoded_idx=input_out, prev_state=s, \n",
    "                        weighted=w, order=j)\n",
    "        tmp_data = tmp_out.data\n",
    "        tmp_data[:,:,0].zero_()\n",
    "        tmp_out = Variable(tmp_data)\n",
    "        out = torch.cat([out,tmp_out],dim=1)\n",
    "#     decoder_in = y[:,j] # train with ground truth\n",
    "    decoder_in = out[:,-1].max(1)[1].squeeze() # train with sequence outputs\n",
    "    out_list.append(out[:,-1].max(1)[1].squeeze().cpu().data.numpy())\n",
    "# with open(save_dir,'a') as f:\n",
    "#     out = np.hstack(tup=(out,iden))\n",
    "#     f.write('\\n')\n",
    "#     for line in out:\n",
    "#         f.write(','.join([str(y_) for y_ in line])+'\\n')\n",
    "out_list = np.array(out_list).transpose()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "collapsed": false,
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==============================================================\n",
      "[INPUT]\n",
      "ADDRESS_HEADERS = set ( [ ' from ' , ' sender ' , ' reply - to ' , ' to ' , ' cc ' , ' bcc ' , ' resent - from ' , ' resent - sender ' , ' resent - to ' , ' resent - cc ' , ' resent - bcc ' , ] )\n",
      "[GROUND OUTPUT]\n",
      "ADDRESS_HEADERS is a set containing strings : ' from ' , ' sender ' , ' reply - to ' , ' to ' , ' cc ' , ' bcc ' , ' resent - from ' , ' resent - sender ' ,\n",
      "[PREDICTED OUTPUT]\n",
      "' ' ' , ' resent ' for ' , ' , ' for ' , ' for ' , ' for ' , ' for ' , ' for ' , ' for ' , ' for ' , ' for ' ' ' '\n",
      "==============================================================\n",
      "[INPUT]\n",
      "raise InvalidTemplateLibrary ( \" Unsupported arguments to \" \" Library . filter : ( % r , % r ) \" , ( name , filter_func ) )\n",
      "[GROUND OUTPUT]\n",
      "raise an InvalidTemplateLibrary exception with an argument string ( \" Unsupported arguments to Library . filter : ( % r , % r ) \" ,\n",
      "[PREDICTED OUTPUT]\n",
      "raise an InvalidTemplateLibrary exception with an argument string \" Unsupported arguments % s . \" , where ' % s ' is replaced by :\n",
      "==============================================================\n",
      "[INPUT]\n",
      "@ property\n",
      "[GROUND OUTPUT]\n",
      "where ' % s ' is replaced with self . _ _ class _ _ . _ _ name _ _ .   property decorator ,\n",
      "[PREDICTED OUTPUT]\n",
      "property decorator ,\n",
      "==============================================================\n",
      "[INPUT]\n",
      "IDENTIFIER = re . compile ( ' ^[a - z_][a - z0 - 9_]*$ ' , re . I )\n",
      "[GROUND OUTPUT]\n",
      "compile regex from string ' ^[a - z_][a - z0 - 9_]*$ ' in case insensitive mode , substitute it for IDENTIFIER .\n",
      "[PREDICTED OUTPUT]\n",
      "call the function re . compile with 2 arguments : raw string ' < - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
      "==============================================================\n",
      "[INPUT]\n",
      "if len ( fixture_files_in_dir ) > 1 :\n",
      "[GROUND OUTPUT]\n",
      "and result of the function humanize called with an argument fixture_dir .   if length of fixture_files_in_dir is greater than 1 ,\n",
      "[PREDICTED OUTPUT]\n",
      "if length of fixture_files_in_dir is not equal to integer 1 ,\n",
      "==============================================================\n",
      "[INPUT]\n",
      "raise NotImplementedError ( ' subclasses of Storage must provide a url ( ) method ' )\n",
      "[GROUND OUTPUT]\n",
      "raise an NotImplementedError exception with argument string ' subclasses of Storage must provide a url ( ) method ' .\n",
      "[PREDICTED OUTPUT]\n",
      "raise an NotImplementedError exception with argument string ' subclasses of ' subclasses of provide ( ) ( ) ' .\n",
      "==============================================================\n",
      "[INPUT]\n",
      "max_value = self . max_expr . resolve ( context )\n",
      "[GROUND OUTPUT]\n",
      "call the method self . max_expr . resolve with an argument context , substitute the result for max_value .\n",
      "[PREDICTED OUTPUT]\n",
      "call the method self . resolve with an argument context , substitute the result for max_value .\n",
      "==============================================================\n",
      "[INPUT]\n",
      "raise PageNotAnInteger ( ' That page number is not an integer ' )\n",
      "[GROUND OUTPUT]\n",
      "raise an exception PageNotAnInteger with string ' That page number is not an integer ' as an argument .\n",
      "[PREDICTED OUTPUT]\n",
      "raise an PageNotAnInteger exception with an argument string ' That number number not not not not an not ' .\n",
      "==============================================================\n",
      "[INPUT]\n",
      "size = property ( _ get_size , _ set_size )\n",
      "[GROUND OUTPUT]\n",
      "size is a property object with _ get_size as getter method and _ set_size as setter method .\n",
      "[PREDICTED OUTPUT]\n",
      "get the attribute named _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ , substitute it for _ _ _ _ _ _\n",
      "==============================================================\n",
      "[INPUT]\n",
      "return I18N_MODIFIED if filename . endswith ( ' . mo ' ) else FILE_MODIFIED\n",
      "[GROUND OUTPUT]\n",
      "if filename ends with string ' . mo ' return I18N_MODIFIED , otherwise return FILE_MODIFIED .\n",
      "[PREDICTED OUTPUT]\n",
      "if if if if if if if if if if if if if if if if exists is not equal to a string ' . ' ,\n",
      "==============================================================\n",
      "[INPUT]\n",
      "from django . core . serializers . base import SerializerDoesNotExist\n",
      "[GROUND OUTPUT]\n",
      "from django . core . serializers . base import SerializerDoesNotExist into default name space .\n",
      "[PREDICTED OUTPUT]\n",
      "from django . base . base . base import import into default name space .\n",
      "==============================================================\n",
      "[INPUT]\n",
      "def _ _ repr _ _ ( self ) :\n",
      "[GROUND OUTPUT]\n",
      "define the method _ _ repr _ _ with an argument self .\n",
      "[PREDICTED OUTPUT]\n",
      "define the method _ _ repr _ _ with an argument self .\n",
      "==============================================================\n",
      "[INPUT]\n",
      "class SerializationError ( Exception ) :\n",
      "[GROUND OUTPUT]\n",
      "derive the class SerializationError from the Exception base class .\n",
      "[PREDICTED OUTPUT]\n",
      "derive the class class from the Exception base class .\n",
      "==============================================================\n",
      "[INPUT]\n",
      "data = str ( data )\n",
      "[GROUND OUTPUT]\n",
      "convert data to string , substitute it for data .\n",
      "[PREDICTED OUTPUT]\n",
      "convert data into a string , substitute it for data .\n",
      "==============================================================\n",
      "[INPUT]\n",
      "def rendered_content ( self ) :\n",
      "[GROUND OUTPUT]\n",
      "define the method rendered_content with an argument self .\n",
      "[PREDICTED OUTPUT]\n",
      "define the method rendered_content with an argument self .\n",
      "==============================================================\n",
      "[INPUT]\n",
      "def escape_quotes ( m ) :\n",
      "[GROUND OUTPUT]\n",
      "define the function escape_quotes with an argument m .\n",
      "[PREDICTED OUTPUT]\n",
      "define the function escape_quotes with an argument m .\n",
      "==============================================================\n",
      "[INPUT]\n",
      "except ImportError :\n",
      "[GROUND OUTPUT]\n",
      "if ImportError exception is caught ,\n",
      "[PREDICTED OUTPUT]\n",
      "if ImportError exception is caught ,\n",
      "==============================================================\n",
      "[INPUT]\n",
      "return \" \"\n",
      "[GROUND OUTPUT]\n",
      "return an empty string .\n",
      "[PREDICTED OUTPUT]\n",
      "return an empty string .\n",
      "==============================================================\n",
      "[INPUT]\n",
      "supports_microseconds = False\n",
      "[GROUND OUTPUT]\n",
      "supports_microseconds is boolean False .\n",
      "[PREDICTED OUTPUT]\n",
      "supports_microseconds is boolean False .\n",
      "==============================================================\n",
      "[INPUT]\n",
      "else :\n",
      "[GROUND OUTPUT]\n",
      "if not ,\n",
      "[PREDICTED OUTPUT]\n",
      "if not ,\n"
     ]
    }
   ],
   "source": [
    "for i in range(len(input_out)): # for each sample in batch\n",
    "    print(\"==============================================================\")\n",
    "    i_line = input_out[i]\n",
    "    i_out = [idx for idx in IdxToWords(i_line,i2w) if (idx!='<PAD>') & (idx!='<SOS>') & (idx!='<EOS>')]\n",
    "    print(\"[INPUT]\")\n",
    "    print(' '.join(i_out))\n",
    "    a_line = output_out[i]\n",
    "    a_out = [idx for idx in IdxToWords(a_line,i2w) if (idx!='<PAD>') & (idx!='<SOS>') & (idx!='<EOS>')]\n",
    "    print(\"[GROUND OUTPUT]\")\n",
    "    print(' '.join(a_out))\n",
    "    o_line = out_list[i]\n",
    "    o_out = []\n",
    "    for idx in o_line:\n",
    "        if idx==w2i['<EOS>']:\n",
    "            break\n",
    "        if (idx!=w2i['<PAD>']) & (idx!=w2i['<SOS>']):\n",
    "            o_out.append(idx)\n",
    "    o_out = IdxToWords(o_out,i2w)\n",
    "    print(\"[PREDICTED OUTPUT]\")\n",
    "    print(' '.join(o_out))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "for line in output_out:\n",
    "    out = [idx for idx in IdxToWords(line,i2w) if idx!='<PAD>']\n",
    "    print(' '.join(out))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "for line in out_list:\n",
    "    out = [idx for idx in IdxToWords(line,i2w) if (idx!='<EOS>')]\n",
    "    print(' '.join(out))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "# debug\n",
    "# get a sample input, ground truth, output\n",
    "idx = 3\n",
    "print(\"input: \",x[idx].cpu().data.numpy())\n",
    "print(\"input: \",' '.join(IdxToWords(x[idx].cpu().data.numpy(),i2w)))\n",
    "print(\"truth: \",y[idx].cpu().data.numpy())\n",
    "print(\"truth: \",' '.join(IdxToWords(y[idx].cpu().data.numpy(),i2w)))\n",
    "O = torch.cat(decoder_debug.O,1)\n",
    "out_sample = []\n",
    "for o in O[idx].max(1)[1].cpu().numpy().squeeze():\n",
    "    if o==w2i['<EOS>']:\n",
    "        break\n",
    "    else:\n",
    "        out_sample.append(o)\n",
    "\n",
    "print(\"output: \",out_sample)\n",
    "print(\"output: \",' '.join(IdxToWords(out_sample,i2w)))\n",
    "A = torch.stack(decoder_debug.A,1)\n",
    "A2 = torch.stack(decoder_debug.A2,1)\n",
    "P = torch.stack(decoder_debug.P,1)\n",
    "I = torch.stack(decoder_debug.I,1)\n",
    "# E = torch.stack(decoder_debug.E,1)\n",
    "S = torch.stack(decoder_debug.S,1)\n",
    "W = torch.cat(decoder_debug.W,1)\n",
    "Y = torch.stack(decoder_debug.Y,1)\n",
    "# scores = torch.stack(decoder_debug.scores,1)\n",
    "sc = torch.stack(decoder_debug.sc,1)\n",
    "pcg = torch.stack(decoder_debug.prob_c_to_g,1)\n",
    "pc = P[:,:,vocab_size:]\n",
    "pg = P[:,:,:vocab_size]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "visualize(Variable(pc[idx][6:12]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "out_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "b = a.data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "20\n",
      "40\n",
      "60\n",
      "80\n",
      "100\n",
      "120\n",
      "140\n",
      "160\n",
      "180\n",
      "200\n"
     ]
    }
   ],
   "source": [
    "batch_count = 0\n",
    "print_list = []\n",
    "while batch_count<len(test):\n",
    "    ################################# validation ##################################\n",
    "    input_out, output_out, in_len, out_len = toData(test[batch_count:min(batch_count+20,len(test))])\n",
    "    batch_count+=20\n",
    "    print(batch_count)\n",
    "    input_mask = np.array(input_out>0, dtype=int)\n",
    "    x = numpy_to_var(input_out)\n",
    "    y = numpy_to_var(output_out)\n",
    "    encoded, _ = encoder(x)\n",
    "    decoder_in, s, w = decoder_initial(x.size(0))\n",
    "    out_list=[]\n",
    "    for j in range(y.size(1)): # for all sequences\n",
    "        if j==0:\n",
    "    #         out, s, w = decoder(input_idx=decoder_in, encoded=encoded,\n",
    "            out, s, w = decoder_debug(input_idx=decoder_in, encoded=encoded,\n",
    "                            encoded_idx=input_out, prev_state=s, \n",
    "                            weighted=w, order=j)\n",
    "        else:\n",
    "    #         tmp_out, s, w = decoder(input_idx=decoder_in, encoded=encoded,\n",
    "            tmp_out, s, w = decoder_debug(input_idx=decoder_in, encoded=encoded,\n",
    "                            encoded_idx=input_out, prev_state=s, \n",
    "                            weighted=w, order=j)\n",
    "            tmp_data = tmp_out.data\n",
    "            tmp_data[:,:,0].zero_()\n",
    "            tmp_out = Variable(tmp_data)\n",
    "            out = torch.cat([out,tmp_out],dim=1)\n",
    "    #     decoder_in = y[:,j] # train with ground truth\n",
    "        decoder_in = out[:,-1].max(1)[1].squeeze() # train with sequence outputs\n",
    "        out_list.append(out[:,-1].max(1)[1].squeeze().cpu().data.numpy())\n",
    "    # with open(save_dir,'a') as f:\n",
    "    #     out = np.hstack(tup=(out,iden))\n",
    "    #     f.write('\\n')\n",
    "    #     for line in out:\n",
    "    #         f.write(','.join([str(y_) for y_ in line])+'\\n')\n",
    "    out_list = np.array(out_list).transpose()\n",
    "    for i in range(len(input_out)): # for each sample in batch\n",
    "#         print(\"==============================================================\")\n",
    "        i_line = input_out[i]\n",
    "        i_out = [idx for idx in IdxToWords(i_line,i2w) if (idx!='<PAD>') & (idx!='<SOS>') & (idx!='<EOS>')]\n",
    "#         print(\"[INPUT]\")\n",
    "#         print(' '.join(i_out))\n",
    "        i_line = '[INPUT]\\n'+' '.join(i_out)\n",
    "\n",
    "        a_line = output_out[i]\n",
    "        a_out = [idx for idx in IdxToWords(a_line,i2w) if (idx!='<PAD>') & (idx!='<SOS>') & (idx!='<EOS>')]\n",
    "#         print(\"[GROUND OUTPUT]\")\n",
    "#         print(' '.join(a_out))\n",
    "        a_line = '[GROUND TRUTH]\\n# '+' '.join(a_out)\n",
    "\n",
    "        o_line = out_list[i]\n",
    "        o_out = []\n",
    "        for idx in o_line:\n",
    "            if idx==w2i['<EOS>']:\n",
    "                break\n",
    "            if (idx!=w2i['<PAD>']) & (idx!=w2i['<SOS>']):\n",
    "                o_out.append(idx)\n",
    "        o_out = IdxToWords(o_out,i2w)\n",
    "#         print(\"[PREDICTED OUTPUT]\")\n",
    "#         print(' '.join(o_out))\n",
    "        o_line = '[PREDICTED]\\n# '+' '.join(o_out)\n",
    "        print_line = '\\n'.join([i_line,a_line,o_line])\n",
    "        print_list.append(print_line)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "with open('django-results.py','w') as f:\n",
    "    f.write('\\n\\n'.join(print_list))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [conda env:pytorch]",
   "language": "python",
   "name": "conda-env-pytorch-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
